Joeyonng


45  Transformer

import math

import torch
import torch.nn as nn

def masked_X(X, valid_lens, value=-1e6):
    '''
    Calculate the masked version of `X` based on `valid_lens`. Usually `X` is
    the scores used in dot product attention function with shape (batch_size,
    num_queries, num_keys). Consider the i-th sequence in the batch. To specify
    that the j-th query should only attend to the first k keys (i.e. the keys
    at index k and after are masked), we set `valid_lens[i][j] = k`. If
    `valid_lens` is a 1d tensor, it is treated as
    `valid_lens[i][j] = valid_lens[i]` for all `j`s.

    Args:
        X (torch.Tensor): The tensor to be masked. Should have shape of
            (batch_size, num_queries, num_keys).
        valid_lens (torch.Tensor): Specify what entries in X are removed.
            Tensor of shape (batch_size, ) or (batch_size, num_queries).
            If shape (batch_size, ), set `X[i, :, valid_lens[i]:] = value` for
            every `i in range(len(valid_lens))`.
            If (batch_size, num_queries), set `X[i, j, valid_lens[i][j]:] = value`
            for every `i in range(len(valid_lens))` and `j in range(len(valid_lens[i]))`.
        value (float): The value to set for the removed entries in `X`.

    Returns:
        (torch.Tensor): Masked score of shape (batch_size, num_queries, num_keys).
    '''
    # Repeat or reshape `valid_lens` to have length batch_size * num_queries.
    # Each number in `valid_lens` is then the number of valid keys for one
    # query, i.e. for one row of `X` after flattening.
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, X.shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    # Create a 1d range array [0, ..., num_keys - 1]
    seq_len_range = torch.arange((X.shape[-1]), dtype=torch.float32, device=X.device)

    # Create masks by broadcasting
    masks = seq_len_range[None, :] < valid_lens[:, None]

    # Clone so that the caller's `X` is not modified in place.
    masked_X = X.reshape(-1, X.shape[-1]).clone()
    masked_X[~masks] = value
    masked_X = masked_X.reshape(X.shape)

    return masked_X

X = torch.rand(2, 2, 4)
valid_lens = torch.tensor([2, 3])
print(masked_X(X, valid_lens))
tensor([[[ 9.7588e-01,  2.7506e-01, -1.0000e+06, -1.0000e+06],
         [ 4.3251e-01,  6.0360e-01, -1.0000e+06, -1.0000e+06]],

        [[ 2.2331e-02,  6.7170e-01,  5.6041e-01, -1.0000e+06],
         [ 6.1315e-02,  4.9694e-02,  4.6259e-01, -1.0000e+06]]])
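As a cross-check, the same masking can be sketched with a boolean mask and `Tensor.masked_fill`, which returns a new tensor instead of writing into its input. This is a minimal sketch that assumes a 1d `valid_lens`; the helper name `sequence_mask` is hypothetical. It also shows that masked positions receive (numerically) zero weight once the softmax is applied.

```python
import torch

def sequence_mask(scores, valid_lens, value=-1e6):
    # Hypothetical helper: `scores` has shape (batch_size, num_queries, num_keys)
    # and `valid_lens` has shape (batch_size,).
    num_keys = scores.shape[-1]
    # keep[b, k] is True when key k is among the first valid_lens[b] keys
    keep = torch.arange(num_keys, device=scores.device)[None, :] < valid_lens[:, None]
    # Broadcast the (batch_size, 1, num_keys) mask over all queries; masked_fill
    # returns a new tensor and leaves `scores` unchanged.
    return scores.masked_fill(~keep[:, None, :], value)

scores = torch.rand(2, 2, 4)
valid_lens = torch.tensor([2, 3])
masked = sequence_mask(scores, valid_lens)

# Entries at or beyond the valid length get (numerically) zero weight
weights = torch.softmax(masked, dim=-1)
print(weights)
```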

Scaled Dot Product Attention

The attention function in transformers is a mechanism that calculates a weighted combination of input values to capture dependencies between tokens in a sequence, regardless of their distance.

The most commonly used attention function is scaled dot-product attention:

\mathrm{Attention} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax} \left( \frac{ \mathbf{Q} \mathbf{K}^{\top} }{ \sqrt{d_k} } \right) \mathbf{V},

where:

  • \mathbf{Q}: Query matrix. Each row in \mathbf{Q} is a query token.

  • \mathbf{K}: Key matrix. Each row in \mathbf{K} is a key token.

  • \mathbf{V}: Value matrix. Each row in \mathbf{V} is a value token.

  • d_{k}: Dimensionality of the query and key tokens (they must match for the dot product).

Explanations:

  1. Compute dot products between \mathbf{Q} and \mathbf{K} to get similarity scores. The i-th row of the matrix \mathbf{Q} \mathbf{K}^{\top} contains the similarity scores (dot products) between the i-th query token (the i-th row of \mathbf{Q}) and all key tokens in \mathbf{K}.

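  2. Scale the scores by dividing them by \sqrt{d_k}. For large d_k the dot products grow large in magnitude, which pushes the softmax into regions where its gradients are extremely small; the scaling keeps the scores in a range where the softmax still has useful gradients.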

  3. Apply softmax to convert the scores into attention weights. The softmax is applied to each row of the matrix \frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_{k}}}, i.e. each row of \mathrm{softmax} \left( \frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_{k}}} \right) sums to 1.

  4. Multiply the weights by \mathbf{V} to produce the attention output.
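The effect of the \sqrt{d_k} factor in step 2 can be checked numerically. The sketch below assumes query and key entries are independent with zero mean and unit variance (the usual argument, not something the attention function enforces). Under that assumption the dot product has variance d_k, and dividing by \sqrt{d_k} brings it back to about 1.

```python
import torch

torch.manual_seed(0)
num_samples = 10_000
results = {}

for d_k in (4, 64, 512):
    # Queries and keys with i.i.d. standard normal entries (an assumption)
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)      # unscaled dot products, variance about d_k
    scaled = dots / d_k ** 0.5      # scaled dot products, variance about 1
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k}: var(q.k)={results[d_k][0]:.1f}, "
          f"var(q.k / sqrt(d_k))={results[d_k][1]:.2f}")
```

Without the scaling, the spread of the scores (and hence how peaked the softmax becomes) grows with d_k; with it, the spread stays roughly constant.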

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        queries: (batch_size, num_queries, query_dim)
        keys: (batch_size, num_keys, query_dim)
        values: (batch_size, num_keys, value_dim)
        valid_lens: (batch_size, ) or (batch_size, num_queries)
        '''

        # scores: (batch_size, num_queries, num_keys)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(queries.shape[-1])
        if valid_lens is not None:
            scores = masked_X(scores, valid_lens)

        # attention_weights: (batch_size, num_queries, num_keys)
        self.attention_weights = torch.softmax(scores, dim=-1)
        # attentions: (batch_size, num_queries, value_dim)
        attentions = torch.bmm(self.dropout(self.attention_weights), values)

        return attentions

queries = torch.normal(0, 1, (2, 4, 2))
keys = torch.normal(0, 1, (2, 6, 2))
values = torch.normal(0, 1, (2, 6, 4))
valid_lens = torch.tensor([2, 4])

attention = ScaledDotProductAttention(dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
print(attention.attention_weights)
tensor([[[ 0.1430, -0.1285, -0.1106,  0.7467],
         [ 0.2903, -0.1678,  0.3139,  0.9371],
         [ 0.0870, -0.1135, -0.2719,  0.6743],
         [ 0.0252, -0.0970, -0.4500,  0.5944]],

        [[-0.2120, -0.1235, -0.6992,  0.7657],
         [-0.2176, -0.1362, -0.6724,  0.7451],
         [-0.0597, -0.1562, -0.7385,  0.7999],
         [-0.1990, -0.2225, -0.5296,  0.6328]]])
tensor([[[0.4356, 0.5644, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1318, 0.8682, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5511, 0.4489, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6785, 0.3215, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.2134, 0.2993, 0.2441, 0.2431, 0.0000, 0.0000],
         [0.2087, 0.2950, 0.2389, 0.2573, 0.0000, 0.0000],
         [0.2731, 0.2140, 0.2473, 0.2657, 0.0000, 0.0000],
         [0.1981, 0.2416, 0.2108, 0.3494, 0.0000, 0.0000]]])
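The formula can also be written directly and cross-checked against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later). In that function, a boolean `attn_mask` marks with `True` the positions that may be attended to, which is the complement of the entries that `masked_X` overwrites. A minimal sketch assuming a 1d `valid_lens`; the helper `attention_from_formula` is hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def attention_from_formula(Q, K, V, keep=None):
    # softmax(Q K^T / sqrt(d_k)) V, computed batch-wise
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if keep is not None:
        # Excluded keys get -inf, so their softmax weight is exactly 0
        scores = scores.masked_fill(~keep, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q = torch.randn(2, 4, 2)
K = torch.randn(2, 6, 2)
V = torch.randn(2, 6, 4)
valid_lens = torch.tensor([2, 4])

# keep: (batch_size, 1, num_keys), broadcast over all queries
keep = (torch.arange(K.shape[1])[None, :] < valid_lens[:, None])[:, None, :]

ours = attention_from_formula(Q, K, V, keep)
fused = F.scaled_dot_product_attention(Q, K, V, attn_mask=keep)
print(torch.allclose(ours, fused, atol=1e-5))
```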

Multi-Head Attention

In practice, given the same set of queries, keys, and values, we may want our model to combine different aspects of the knowledge in the data, which can be mathematically represented using different subspaces of the same data.

Given n tokens as rows of a matrix \mathbf{X}, they can be projected into vectors in another subspace by using a linear transformation matrix \mathbf{W}

\mathbf{X}' = \mathbf{X} \mathbf{W}.

The idea of multi-head attention is to perform the same attention mechanism on different learnable subspaces of the same queries, keys, and values. The results are then concatenated and linearly transformed again to combine the information from different aspects of the data:

\mathrm{MultiHead} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \begin{bmatrix} \mathrm{head}_{1}, \dots, \mathrm{head}_{h} \end{bmatrix} \mathbf{W}_{O},

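where each head i has its own learnable projection matrices \mathbf{W}_{Q}^{(i)}, \mathbf{W}_{K}^{(i)}, and \mathbf{W}_{V}^{(i)}:

\mathrm{head}_{i} = \mathrm{Attention} (\mathbf{Q} \mathbf{W}_{Q}^{(i)}, \mathbf{K} \mathbf{W}_{K}^{(i)}, \mathbf{V} \mathbf{W}_{V}^{(i)}).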

Parallel Implementation

The linear transformations of the inputs for each attention head can be implemented using fully connected (linear) layers. A vanilla implementation of the multi-head attention layer creates a separate set of linear layers for each attention head and uses a for loop to get the outputs from the h heads.

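# Illustrative sizes: h heads, each projecting to head_hidden_dim features
h, head_hidden_dim = 2, 2
Q = K = V = torch.ones((2, 4, 4))
attention = ScaledDotProductAttention(dropout=0)

multi_head_outputs = []
for i in range(h):
    # A separate set of linear layers for the i-th head
    W_q = nn.LazyLinear(head_hidden_dim)
    W_k = nn.LazyLinear(head_hidden_dim)
    W_v = nn.LazyLinear(head_hidden_dim)
    output = attention(W_q(Q), W_k(K), W_v(V))
    multi_head_outputs.append(output)

# Concatenate the head outputs: (2, 4, h * head_hidden_dim)
concat_output = torch.cat(multi_head_outputs, dim=-1)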

A parallel version of the same operation can be implemented using a single set of linear layers. To understand this, first observe the following facts.

  1. Batch processing of the inputs is supported in Attention(). That is, the input X in Attention(X, X, X) can have a shape of (batch_size, seq_len, hidden_dim), so different heads can be treated as different entries of the batch.

  2. The h different linear transformations \mathbf{W}_{1}, \dots, \mathbf{W}_{h} \in \mathbb{R}^{m \times d} of the same input \mathbf{X} \in \mathbb{R}^{n \times m}, where m is the input feature dimension and d is the per-head dimension, can be grouped and replaced by a single linear transformation matrix \mathbf{W} \in \mathbb{R}^{m \times dh}:

\begin{bmatrix} | & & | \\ \mathbf{X} \mathbf{W}_{1} & \dots & \mathbf{X} \mathbf{W}_{h} \\ | & & | \end{bmatrix} = \mathbf{X} \begin{bmatrix} | & & | \\ \mathbf{W}_{1} & \dots & \mathbf{W}_{h} \\ | & & | \end{bmatrix} = \mathbf{X} \mathbf{W}.
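This identity is what allows a single projection to replace h separate ones. A quick numerical check with plain matrices (the sizes are arbitrary): splitting the columns of \mathbf{X} \mathbf{W} into h consecutive chunks of width d recovers each \mathbf{X} \mathbf{W}_{i}.

```python
import torch

torch.manual_seed(0)
n, m, d, h = 5, 8, 3, 4          # tokens, input dim, per-head dim, heads

X = torch.randn(n, m)
W_heads = [torch.randn(m, d) for _ in range(h)]   # W_1, ..., W_h
W = torch.cat(W_heads, dim=1)                     # [W_1 ... W_h]: (m, d * h)

# Left-hand side: apply each head's matrix separately, then concatenate
lhs = torch.cat([X @ W_i for W_i in W_heads], dim=1)
# Right-hand side: a single matrix product
rhs = X @ W
print(torch.allclose(lhs, rhs, atol=1e-5))

# The i-th chunk of d consecutive columns of X W is exactly X W_i
chunks = rhs.reshape(n, h, d)
print(torch.allclose(chunks[:, 2, :], X @ W_heads[2], atol=1e-5))
```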

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_dim, dropout=0.1, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention(dropout)

        # Here we set `hidden_dim / num_heads` as the dim of the projected
        # queries, keys, and values in each head.
        # We use `hidden_dim` instead of `hidden_dim / num_heads` as the output
        # dim of `W_q`, `W_k`, and `W_v` to enable parallel computation for all
        # heads, i.e. the i-th chunk of `hidden_dim / num_heads` outputs is used
        # by the i-th head.
        self.W_q = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_k = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_v = nn.LazyLinear(hidden_dim, bias=bias)

        # The input dim of `W_o` is also `hidden_dim`, because the outputs of the
        # `num_heads` heads are concatenated before being fed into `W_o`.
        self.W_o = nn.LazyLinear(hidden_dim, bias=bias)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        The forward computation of a multi-head attention layer.

        Args:
            queries (batch_size, num_queries, query_dim)
            keys (batch_size, num_keys, query_dim)
            values (batch_size, num_keys, value_dim)
            valid_lens (batch_size, ), (batch_size, num_queries), or None for
                no masking

        Returns:
            (batch_size, num_queries, hidden_dim)
        '''

        # multi_queries: (batch_size * num_heads, num_queries, hidden_dim / num_heads)
        multi_queries = self.transpose_qkv(self.W_q(queries))
        # multi_keys: (batch_size * num_heads, num_keys, hidden_dim / num_heads)
        multi_keys = self.transpose_qkv(self.W_k(keys))
        # multi_values: (batch_size * num_heads, num_keys, hidden_dim / num_heads)
        multi_values = self.transpose_qkv(self.W_v(values))

        # Repeat each element in `valid_lens` `num_heads` times to align with
        # the shape of `multi_queries`. Without `valid_lens`, no mask is applied.
        multi_valid_lens = None
        if valid_lens is not None:
            # multi_valid_lens: (batch_size * num_heads, ) or (batch_size * num_heads, num_queries)
            multi_valid_lens = valid_lens.repeat_interleave(
                self.num_heads, dim=0
            )

        # multi_output: (batch_size * num_heads, num_queries, hidden_dim / num_heads)
        multi_output = self.attention(
            multi_queries, multi_keys, multi_values, valid_lens=multi_valid_lens
        )

        # output: (batch_size, num_queries, hidden_dim)
        output = self.W_o(self.transpose_output(multi_output))

        return output

    def transpose_qkv(self, X):
        '''
        Reshape X for parallel computation of multiple attention heads. Assume
        `X` has shape (batch_size, seq_len, hidden_dim) and `hidden_dim` is
        divisible by `num_heads`. We want to make it `(batch_size * num_heads,
        seq_len, hidden_dim / num_heads)`, so that the self-attention performed
        later is done on `hidden_dim / num_heads` dimension.

        Args:
            X (batch_size, seq_len, hidden_dim)

        Returns:
            (batch_size * num_heads, seq_len, hidden_dim / num_heads)
        '''

        # X: (batch_size, seq_len, num_heads, hidden_dim / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # X: (batch_size, num_heads, seq_len, hidden_dim / num_heads)
        X = X.permute(0, 2, 1, 3)
        # X: (batch_size * num_heads, seq_len, hidden_dim / num_heads)
        X = X.reshape(-1, X.shape[2], X.shape[3])

        return X

    def transpose_output(self, X):
        '''
        Reverse the operation of `transpose_qkv`.

        Args:
            X (batch_size * num_heads, seq_len, hidden_dim / num_heads)

        Returns:
            (batch_size, seq_len, hidden_dim)
        '''

        # X: (batch_size, num_heads, seq_len, hidden_dim / num_heads)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        # X: (batch_size, seq_len, num_heads, hidden_dim / num_heads)
        X = X.permute(0, 2, 1, 3)
        # X: (batch_size, seq_len, hidden_dim)
        X = X.reshape(X.shape[0], X.shape[1], -1)

        return X

attention = MultiHeadAttention(2, 4, 0)
queries = torch.ones((2, 4, 4))
keys = torch.ones((2, 6, 4))
values = torch.ones((2, 6, 6))
valid_lens = torch.tensor([3, 2])
print(attention(queries, keys, values, valid_lens).shape)
torch.Size([2, 4, 4])
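To see that the reshape-and-permute steps of `transpose_qkv` and `transpose_output` compute the same thing as running h heads one by one, here is a standalone check. It uses plain weight matrices instead of `nn.LazyLinear` and omits masking and dropout; the helpers `split_heads`, `merge_heads`, and `sdpa` are hypothetical stand-ins for `transpose_qkv`, `transpose_output`, and `ScaledDotProductAttention`.

```python
import math
import torch

def sdpa(Q, K, V):
    # Unmasked scaled dot-product attention, batched over the first dim
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    return torch.softmax(scores, dim=-1) @ V

def split_heads(X, h):
    # (b, n, h * d) -> (b * h, n, d), same steps as `transpose_qkv`
    b, n, _ = X.shape
    return X.reshape(b, n, h, -1).permute(0, 2, 1, 3).reshape(b * h, n, -1)

def merge_heads(X, h):
    # (b * h, n, d) -> (b, n, h * d), same steps as `transpose_output`
    _, n, d = X.shape
    return X.reshape(-1, h, n, d).permute(0, 2, 1, 3).reshape(-1, n, h * d)

torch.manual_seed(0)
b, n, m, h, d = 2, 5, 8, 4, 2        # batch, tokens, input dim, heads, head dim
X = torch.randn(b, n, m)
W_q, W_k, W_v = (torch.randn(m, h * d) / math.sqrt(m) for _ in range(3))

# Parallel: one projection each for Q, K, V; heads folded into the batch dim
parallel = merge_heads(
    sdpa(split_heads(X @ W_q, h), split_heads(X @ W_k, h), split_heads(X @ W_v, h)), h
)

# Loop: head i uses columns [i * d, (i + 1) * d) of each shared weight matrix
cols = [slice(i * d, (i + 1) * d) for i in range(h)]
looped = torch.cat(
    [sdpa(X @ W_q[:, c], X @ W_k[:, c], X @ W_v[:, c]) for c in cols], dim=-1
)

print(torch.allclose(parallel, looped, atol=1e-5))
```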

Positional Encoding

Given a sequence of n tokens as rows of a matrix \mathbf{X} \in \mathbb{R}^{n \times d}, positional encoding injects positional information into \mathbf{X} by generating a new matrix \mathbf{X}'

\mathbf{X}' = \mathbf{X} + \mathbf{P}

where \mathbf{P} \in \mathbb{R}^{n \times d} is a positional encoding matrix whose i-th row is the positional encoding vector of the i-th token in \mathbf{X}.

Usually \mathbf{P} should provide two types of positional information.

  • Absolute positional information. Each position should receive an encoding that is unique across the entire sequence.

  • Relative positional information. The encodings should capture the relative order (offsets) of the tokens.

Sinusoidal Positional Encoding

In sinusoidal positional encoding, the positional encoding matrix \mathbf{P} uses sine functions at the even columns and cosine functions at the odd columns (columns are 0-indexed), and each sine/cosine column pair has a different period.

Each element p_{i, j} at the i-th row and j-th column in \mathbf{P} is calculated as

p_{i, j} = \begin{aligned} \begin{cases} \sin (\omega_{j} i) & \quad \text{when } j \text{ is even} \\ \cos (\omega_{j} i) & \quad \text{when } j \text{ is odd} \end{cases} \end{aligned}

where

\omega_{j} = \begin{aligned} \begin{cases} 1 \mathbin{/} \left( 10000^{j \mathbin{/} d} \right) & \quad \text{when } j \text{ is even} \\ 1 \mathbin{/} \left( 10000^{(j - 1) \mathbin{/} d} \right) & \quad \text{when } j \text{ is odd}. \end{cases} \end{aligned}
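The piecewise definition can be evaluated entry by entry, which gives a reference for checking a vectorized implementation. This minimal sketch uses 0-indexed i and j, matching the `SinusoidalPositionalEncoding` code in this section (positions start at 0); the helper `p_entry` is hypothetical.

```python
import math
import torch

def p_entry(i, j, d):
    # Hypothetical helper: entry p_{i, j} from the piecewise definition,
    # with 0-indexed position i and column j.
    exponent = (j if j % 2 == 0 else j - 1) / d
    omega = 1 / 10000 ** exponent
    return math.sin(omega * i) if j % 2 == 0 else math.cos(omega * i)

n, d = 10, 6
P_loop = torch.tensor(
    [[p_entry(i, j, d) for j in range(d)] for i in range(n)], dtype=torch.float64
)

# Vectorized version: one frequency per (sine, cosine) column pair
pos = torch.arange(n, dtype=torch.float64).reshape(-1, 1)
omega = 1 / 10000 ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
P_vec = torch.zeros(n, d, dtype=torch.float64)
P_vec[:, 0::2] = torch.sin(pos * omega)
P_vec[:, 1::2] = torch.cos(pos * omega)

print(torch.allclose(P_loop, P_vec))
```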

Encoding absolute positional information

The number of unique encodings that sinusoidal positional encoding can represent depends on d.

  • If d = 1, the encoding in the single column is \sin(i) for i = 1, \dots, n, which has a period of \lambda = 2 \pi. Since the sine function repeats after each period, a positional encoding using a single sine function can only represent roughly \lfloor \lambda \rfloor = 6 tokens before its values start to cycle.

  • If d = 2, the encodings of the 1st and 2nd columns are \sin(i) and \cos(i) for i = 1, \dots, n, which have the same period \lambda = 2\pi. Since each sine column and its paired cosine column always share the same period, the number of unique positions that can be represented with an odd d (e.g. d = 1) is the same as that with the next even d (e.g. d = 2).

  • If d > 2, the number of unique positions that the sinusoidal positional encoding can represent is determined by the least common multiple of the \lceil d \mathbin{/} 2 \rceil different periods, which is very large for a reasonable d.
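Under the definition of \omega_{j} above, the k-th (sine, cosine) column pair (columns 2k and 2k + 1, 0-indexed) has period \lambda_{k} = 2\pi / \omega_{2k} = 2\pi \cdot 10000^{2k / d}. A short computation for an even d = 32 (an arbitrary choice) shows how widely these periods are spread:

```python
import math

d = 32
# One period per (sine, cosine) column pair, k = 0, ..., d/2 - 1
periods = [2 * math.pi * 10000 ** (2 * k / d) for k in range(d // 2)]

print(f"shortest period: {periods[0]:.2f}")
print(f"longest period:  {periods[-1]:.2f}")
```

The periods grow geometrically from 2\pi, so the low-index columns change quickly with the position while the high-index columns change slowly.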

Encoding relative positional information

For any fixed offset \delta, the encodings at position i + \delta can be expressed as a linear transformation of the encodings at position i. To see this, we can use trigonometric sum identities to rewrite the encodings at position i + \delta:

\sin(\omega_{j} (i + \delta)) = \sin(\omega_{j} i) \cos(\omega_{j} \delta) + \cos(\omega_{j} i) \sin(\omega_j \delta),

\cos(\omega_{j} (i + \delta)) = \cos(\omega_{j} i) \cos(\omega_{j} \delta) - \sin(\omega_{j} i) \sin(\omega_j \delta),

which can be represented using the matrix multiplication

\begin{bmatrix} \sin(\omega_{j} (i + \delta)) \\ \cos(\omega_{j} (i + \delta)) \end{bmatrix} = \begin{bmatrix} \cos(\omega_{j} \delta) & \sin(\omega_{j} \delta) \\ - \sin(\omega_{j} \delta) & \cos(\omega_{j} \delta) \end{bmatrix} \begin{bmatrix} \sin(\omega_{j} i) \\ \cos(\omega_{j} i) \end{bmatrix}.

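The positional encoding at i + \delta can therefore be obtained by multiplying each (sine, cosine) pair of the encoding at i by a 2 \times 2 rotation matrix that depends only on the offset \delta and not on the position i (for the whole encoding vector, this is a block-diagonal matrix with one such block per column pair). This shows that the encoding at position i + \delta is a linear function of the encoding at position i.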
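The rotation relation can be verified numerically. This sketch builds \mathbf{P} for an even d, assembles the block-diagonal matrix with one 2 \times 2 block per column pair, and checks that it maps every row i to row i + \delta. Since the rows of \mathbf{P} are row vectors, the column-vector relation p_{i + \delta} = M p_{i} is applied as a product with M^{\top}.

```python
import math
import torch

n, d, delta = 50, 8, 3
pos = torch.arange(n, dtype=torch.float64).reshape(-1, 1)
omega = 1 / 10000 ** (torch.arange(0, d, 2, dtype=torch.float64) / d)

# Sinusoidal encodings: sine at even columns, cosine at odd columns
P = torch.zeros(n, d, dtype=torch.float64)
P[:, 0::2] = torch.sin(pos * omega)
P[:, 1::2] = torch.cos(pos * omega)

# One 2x2 rotation block per column pair; it depends on delta, not on i
blocks = [
    torch.tensor([[math.cos(w * delta), math.sin(w * delta)],
                  [-math.sin(w * delta), math.cos(w * delta)]], dtype=torch.float64)
    for w in omega.tolist()
]
M = torch.block_diag(*blocks)   # (d, d), block diagonal

# Column-vector form p_{i + delta} = M p_i, written for the rows of P
shifted = P[:-delta] @ M.T
print(torch.allclose(shifted, P[delta:]))
```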

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # i: (max_len, 1)
        i = torch.arange(max_len).reshape(-1, 1)
        # two_j: (hidden_dim / 2, )
        two_j = torch.arange(0, hidden_dim, 2)
        # X: (max_len, hidden_dim / 2)
        X = i / torch.pow(10000, two_j / hidden_dim)

        # P is the same for each sequence matrix in the mini-batch.
        # The shape of (1, max_len, hidden_dim) for `P` can be directly
        # broadcasted to (batch_size, max_len, hidden_dim) when added to `X`.
        self.P = torch.zeros((1, max_len, hidden_dim))
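        self.P[:, :, 0::2] = torch.sin(X)
        # For an odd `hidden_dim` there is one fewer cosine column than sine
        # column, so only the first `hidden_dim // 2` frequencies are used.
        self.P[:, :, 1::2] = torch.cos(X[:, :hidden_dim // 2])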

    def forward(self, X):
        '''
        X (batch_size, seq_len, hidden_dim)
        '''

        # `X` may have a seq_len shorter than `max_len`, so only the first
        # seq_len rows of `P` are added to `X`.
        X = X + self.P[:, :X.shape[1], :].to(X.device)

        return self.dropout(X)

encoding_dim, num_steps = 32, 60
pos_encoding = SinusoidalPositionalEncoding(encoding_dim, 0)
X = torch.zeros((10, num_steps, encoding_dim))
X_p = pos_encoding(X)
P = pos_encoding.P[:, :X.shape[1], :]
print(P.shape)
torch.Size([1, 60, 32])
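A consequence of the rotation property is that the dot product between two encoding vectors depends only on their offset: for an even d, p_{i} \cdot p_{i + \delta} = \sum_{k} \cos(\omega_{2k} \delta), because \sin a \sin b + \cos a \cos b = \cos(a - b). The sketch below rebuilds \mathbf{P} with the same formula as `SinusoidalPositionalEncoding` (same d = 32 and 60 positions as the example above) rather than reusing the class, and checks this for \delta = 5.

```python
import torch

n, d, delta = 60, 32, 5
pos = torch.arange(n, dtype=torch.float64).reshape(-1, 1)
omega = 1 / 10000 ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
P = torch.zeros(n, d, dtype=torch.float64)
P[:, 0::2] = torch.sin(pos * omega)
P[:, 1::2] = torch.cos(pos * omega)

# Dot products between every row i and row i + delta
dots = (P[:-delta] * P[delta:]).sum(dim=-1)
# Predicted value: sum over column pairs of cos(omega_k * delta)
expected = torch.cos(omega * delta).sum()

print(torch.allclose(dots, expected.expand_as(dots)))
```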

Transformer Encoder

class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.LazyLinear(ffn_hidden_dim, bias=bias),
            nn.ReLU(),
            nn.LazyLinear(hidden_dim, bias=bias)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, X, valid_lens):
        '''
        X (batch_size, seq_len, hidden_dim)
        valid_lens (batch_size, ) or (batch_size, seq_len)
        '''

        # Y: (batch_size, seq_len, hidden_dim)
        Y = self.norm1(X + self.attention(X, X, X, valid_lens))
        # Z: (batch_size, seq_len, hidden_dim)
        Z = self.norm2(Y + self.ffn(Y))

        return Z

class TransformerEncoder(nn.Module):
    def __init__(self, num_blocks, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.positional_encoding = SinusoidalPositionalEncoding(hidden_dim, dropout)

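        # nn.ModuleList (rather than a plain Python list) registers every
        # block, so its parameters are visible to the optimizer and it follows
        # train()/eval() and .to(device) calls on the encoder.
        self.attention_blocks = nn.ModuleList([
            TransformerEncoderBlock(num_heads, hidden_dim, ffn_hidden_dim, dropout, bias)
            for _ in range(num_blocks)
        ])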

    def forward(self, X, valid_lens):
        '''
        X (batch_size, seq_len, hidden_dim)
        valid_lens (batch_size, ) or (batch_size, seq_len)
        '''

        # X: (batch_size, seq_len, hidden_dim)
        X = self.positional_encoding(X)
        # X: (batch_size, seq_len, hidden_dim)
        for block in self.attention_blocks:
            X = block(X, valid_lens)

        return X

encoder = TransformerEncoder(2, 4, 8, 16, 0.5)
X = torch.ones((2, 4, 8))
valid_lens = torch.tensor([2, 3])
print(encoder(X, valid_lens).shape)
torch.Size([2, 4, 8])
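When stacking blocks, a general PyTorch detail matters: an `nn.Module` only registers submodules that are assigned directly as attributes or held in containers such as `nn.ModuleList`. Submodules kept in a plain Python list are invisible to `parameters()`, `.to(device)`, and `train()`/`eval()`. The two toy classes below are hypothetical and only illustrate the difference.

```python
import torch.nn as nn

class PlainListStack(nn.Module):
    # Hypothetical toy module: layers kept in a plain Python list
    def __init__(self, num_blocks):
        super().__init__()
        self.blocks = [nn.Linear(4, 4) for _ in range(num_blocks)]

class ModuleListStack(nn.Module):
    # Hypothetical toy module: layers kept in an nn.ModuleList
    def __init__(self, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_blocks)])

plain = PlainListStack(2)
registered = ModuleListStack(2)
print(len(list(plain.parameters())))       # 0: the layers are not registered
print(len(list(registered.parameters())))  # 4: weight and bias of both layers

# eval() only reaches registered submodules
plain.eval()
registered.eval()
print(plain.blocks[0].training, registered.blocks[0].training)  # True False
```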

Transformer Decoder

class TransformerDecoderBlock(nn.Module):
    def __init__(self, block_index, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.block_index = block_index
        self.attention1 = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attention2 = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.LazyLinear(ffn_hidden_dim, bias=bias),
            nn.ReLU(),
            nn.LazyLinear(hidden_dim, bias=bias)
        )
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, X, state):
        '''
        Args:
            X (torch.Tensor): (batch_size, seq_len, hidden_dim) during training,
                or (batch_size, 1, hidden_dim) during prediction.
            state (tuple): A tuple (enc_outputs, enc_valid_lens, dec_outputs).
                enc_outputs: (batch_size, num_enc_tokens, hidden_dim), the
                    output of the encoder.
                enc_valid_lens: (batch_size, ) or (batch_size, seq_len)
                dec_outputs: a list of length num_blocks whose i-th entry is
                    either None or the inputs to the i-th block seen so far,
                    with shape (batch_size, num_tokens_so_far, hidden_dim).

        Returns:
            (tuple): The block output (batch_size, seq_len, hidden_dim) and
                the updated `state`.
        '''
        enc_outputs, enc_valid_lens, dec_outputs = state

        # During training, all tokens in a sequence are available, so `X` has
        # shape (batch_size, seq_len, hidden_dim). The mask ensures that each
        # token only attends to itself and the tokens before it.
        if self.training:
            # dec_valid_lens: (batch_size, seq_len)
            # Each row in `dec_valid_lens` is [1, ..., seq_len], which masks out
            # the entries above the diagonal of each score matrix in the batch.
            dec_valid_lens = torch.arange(
                1, X.shape[1] + 1, device=X.device
            ).repeat(X.shape[0], 1)
            Y = self.norm1(X + self.attention1(X, X, X, dec_valid_lens))
        # During prediction, one token is available per call, so `X` has shape
        # (batch_size, 1, hidden_dim). The inputs seen so far are cached in
        # `state` and used as the keys and values of the self-attention.
        else:
            # prev_X: (batch_size, num_tokens_so_far, hidden_dim)
            if dec_outputs[self.block_index] is None:
                prev_X = X
            else:
                prev_X = torch.cat((dec_outputs[self.block_index], X), dim=1)
            state[2][self.block_index] = prev_X
            # No mask is needed: the cache only holds current and past tokens.
            Y = self.norm1(X + self.attention1(X, prev_X, prev_X, None))

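        # Encoder-decoder attention with a residual connection: queries come
        # from the decoder, keys and values from the encoder outputs.
        Z = self.norm2(Y + self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens))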
        return self.norm3(Z + self.ffn(Z)), state
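The training-time mask above (each row of `dec_valid_lens` is [1, ..., seq_len]) lets query j attend to the first j + 1 keys, i.e. the lower triangle of the score matrix including the diagonal. A standalone check against `torch.tril`, reproducing only the comparison that `masked_X` performs:

```python
import torch

batch_size, seq_len = 2, 5
# Same construction as `dec_valid_lens`: every row is [1, ..., seq_len]
dec_valid_lens = torch.arange(1, seq_len + 1).repeat(batch_size, 1)

# Query j may attend to key k iff k < dec_valid_lens[b, j] (as in `masked_X`)
keys = torch.arange(seq_len)
keep = keys[None, None, :] < dec_valid_lens[:, :, None]   # (batch, queries, keys)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(torch.equal(keep, causal.expand(batch_size, -1, -1)))
```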

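class TransformerDecoder(nn.Module):
    def __init__(self, num_blocks, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.positional_encoding = SinusoidalPositionalEncoding(hidden_dim, dropout=dropout)
        self.num_blocks = num_blocks

        # nn.ModuleList registers the blocks' parameters with the decoder.
        self.decoder_blocks = nn.ModuleList([
            TransformerDecoderBlock(i, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias)
            for i in range(num_blocks)
        ])

        self.dense = nn.LazyLinear(hidden_dim)

    def init_state(self, enc_outputs, enc_valid_lens):
        # Start with an empty cache (None) for each decoder block.
        return (enc_outputs, enc_valid_lens, [None] * self.num_blocks)

    def forward(self, X, state):
        '''
        X (batch_size, seq_len, hidden_dim)
        state: the tuple described in `TransformerDecoderBlock.forward`
        '''
        # During prediction, `X` holds only the newest token, so its position
        # is offset by the number of tokens already cached for the first block.
        offset = 0
        if not self.training and state[2][0] is not None:
            offset = state[2][0].shape[1]
        P = self.positional_encoding.P[:, offset:offset + X.shape[1], :]
        X = self.positional_encoding.dropout(X + P.to(X.device))
        for block in self.decoder_blocks:
            X, state = block(X, state)

        return self.dense(X), state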
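At prediction time the decoder blocks cache the inputs seen so far (`prev_X`) so that the newest token can attend to all earlier tokens without reprocessing them. The standalone sketch below uses a hypothetical single-head `attend` function without projections and checks that feeding tokens one at a time with such a cache gives the same outputs as causal attention over the whole sequence at once.

```python
import math
import torch

def attend(Q, K, V, keep=None):
    # Unprojected single-head scaled dot-product attention
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if keep is not None:
        scores = scores.masked_fill(~keep, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
batch_size, seq_len, dim = 2, 6, 4
X = torch.randn(batch_size, seq_len, dim)

# Training-style: the whole sequence at once with a causal mask
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full = attend(X, X, X, causal)

# Prediction-style: one token per step, appended to a cache
cache = None
steps = []
for t in range(seq_len):
    x_t = X[:, t:t + 1, :]                      # (batch_size, 1, dim)
    cache = x_t if cache is None else torch.cat((cache, x_t), dim=1)
    steps.append(attend(x_t, cache, cache))     # no mask needed
incremental = torch.cat(steps, dim=1)

print(torch.allclose(full, incremental, atol=1e-5))
```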